{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use Soft Q Learning to Play LunarLander-v2\n",
    "\n",
    "PyTorch version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import sys\n",
    "import logging\n",
    "import itertools\n",
    "import copy\n",
    "\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "import pandas as pd\n",
    "import gym\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "torch.manual_seed(0)\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.distributions as distributions\n",
    "\n",
    "logging.basicConfig(level=logging.DEBUG,\n",
    "        format='%(asctime)s [%(levelname)s] %(message)s',\n",
    "        stream=sys.stdout, datefmt='%H:%M:%S')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "01:20:05 [INFO] env: <LunarLander<LunarLander-v2>>\n",
      "01:20:05 [INFO] action_space: Discrete(4)\n",
      "01:20:05 [INFO] observation_space: Box(-inf, inf, (8,), float32)\n",
      "01:20:05 [INFO] reward_range: (-inf, inf)\n",
      "01:20:05 [INFO] metadata: {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 50}\n",
      "01:20:05 [INFO] _max_episode_steps: 1000\n",
      "01:20:05 [INFO] _elapsed_steps: None\n",
      "01:20:05 [INFO] id: LunarLander-v2\n",
      "01:20:05 [INFO] entry_point: gym.envs.box2d:LunarLander\n",
      "01:20:05 [INFO] reward_threshold: 200\n",
      "01:20:05 [INFO] nondeterministic: False\n",
      "01:20:05 [INFO] max_episode_steps: 1000\n",
      "01:20:05 [INFO] _kwargs: {}\n",
      "01:20:05 [INFO] _env_name: LunarLander\n"
     ]
    }
   ],
   "source": [
    "# Create the LunarLander environment and seed it with 0.\n",
    "env = gym.make('LunarLander-v2')\n",
    "env.seed(0)\n",
    "\n",
    "# Log every attribute of the environment object, then of its spec.\n",
    "for key, value in vars(env).items():\n",
    "    logging.info('%s: %s', key, value)\n",
    "for key, value in vars(env.spec).items():\n",
    "    logging.info('%s: %s', key, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DQNReplayer:\n",
    "    \"\"\"Fixed-capacity circular store of transitions kept in a DataFrame.\n",
    "\n",
    "    Every row is one (state, action, reward, next_state, done) record. Once\n",
    "    `capacity` rows are filled, new records overwrite the oldest ones.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, capacity):\n",
    "        \"\"\"Allocate an empty table with `capacity` rows.\"\"\"\n",
    "        self.memory = pd.DataFrame(index=range(capacity),\n",
    "                columns=['state', 'action', 'reward', 'next_state', 'done'])\n",
    "        self.i = 0  # row written by the next store() call\n",
    "        self.count = 0  # number of filled rows, never above capacity\n",
    "        self.capacity = capacity\n",
    "\n",
    "    def store(self, *args):\n",
    "        \"\"\"Write one transition, passed as one positional value per column.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the number of values differs from the number of\n",
    "                columns.\n",
    "        \"\"\"\n",
    "        field_n = len(self.memory.columns)\n",
    "        if len(args) != field_n:\n",
    "            raise ValueError('store() expects %d values, got %d'\n",
    "                    % (field_n, len(args)))\n",
    "        # Build the row as an explicit 1-D object array, one cell per field.\n",
    "        # Handing pandas the raw tuple leaves it to coerce a ragged mix of\n",
    "        # arrays and scalars, which is fragile across NumPy/pandas versions.\n",
    "        row = np.empty(field_n, dtype=object)\n",
    "        for k, value in enumerate(args):\n",
    "            row[k] = value\n",
    "        self.memory.loc[self.i] = row\n",
    "        self.i = (self.i + 1) % self.capacity\n",
    "        self.count = min(self.count + 1, self.capacity)\n",
    "\n",
    "    def sample(self, size):\n",
    "        \"\"\"Draw `size` stored transitions uniformly, with replacement.\n",
    "\n",
    "        Returns:\n",
    "            An iterable yielding one stacked ndarray per column, in column\n",
    "            order (state, action, reward, next_state, done).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if nothing has been stored yet.\n",
    "        \"\"\"\n",
    "        if self.count == 0:\n",
    "            raise ValueError('cannot sample from an empty replayer')\n",
    "        indices = np.random.choice(self.count, size=size)\n",
    "        return (np.stack(self.memory.loc[indices, field]) for field in\n",
    "                self.memory.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SQLAgent:\n",
    "    \"\"\"Soft Q-learning agent for a discrete action space.\n",
    "\n",
    "    Actions are sampled from the softmax of Q / alpha. The Q network is\n",
    "    trained on replayed transitions against a soft (log-sum-exp) bootstrap\n",
    "    target computed by a target network, which is re-copied from the Q\n",
    "    network at every reset in 'train' mode.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, gamma=0.99, alpha=0.02, replayer_capacity=10000,\n",
    "            hidden_sizes=(256, 256), learning_rate=3e-4, batch_size=128,\n",
    "            learn_start=500):\n",
    "        \"\"\"Build the Q network, its optimizer and the replay memory.\n",
    "\n",
    "        Args:\n",
    "            env: environment exposing `action_space.n` and\n",
    "                `observation_space.shape`.\n",
    "            gamma: discount factor applied to the bootstrap value.\n",
    "            alpha: temperature of the softmax policy and soft value.\n",
    "            replayer_capacity: maximum number of stored transitions.\n",
    "            hidden_sizes: widths of the hidden layers.\n",
    "            learning_rate: step size of the Adam optimizer.\n",
    "            batch_size: number of transitions replayed per learning step.\n",
    "            learn_start: stored transitions required before learning starts.\n",
    "        \"\"\"\n",
    "        self.action_n = env.action_space.n\n",
    "        self.gamma = gamma\n",
    "        self.alpha = alpha\n",
    "        self.batch_size = batch_size\n",
    "        self.learn_start = learn_start\n",
    "\n",
    "        self.replayer = DQNReplayer(replayer_capacity)\n",
    "\n",
    "        self.evaluate_net = self.build_net(\n",
    "                input_size=env.observation_space.shape[0],\n",
    "                hidden_sizes=list(hidden_sizes), output_size=self.action_n)\n",
    "        self.optimizer = optim.Adam(self.evaluate_net.parameters(),\n",
    "                lr=learning_rate)\n",
    "        self.loss = nn.MSELoss()\n",
    "\n",
    "        # Defined up front so that step() is usable before the first reset().\n",
    "        self.mode = None\n",
    "        self.trajectory = []\n",
    "        self.target_net = copy.deepcopy(self.evaluate_net)\n",
    "\n",
    "    def build_net(self, input_size, hidden_sizes, output_size):\n",
    "        \"\"\"Return an MLP with ReLU between layers and a linear output.\"\"\"\n",
    "        sizes = [input_size, *hidden_sizes, output_size]\n",
    "        layers = []\n",
    "        for in_features, out_features in zip(sizes[:-1], sizes[1:]):\n",
    "            layers.append(nn.Linear(in_features, out_features))\n",
    "            layers.append(nn.ReLU())\n",
    "        # Drop the trailing ReLU: Q values must be free to be negative.\n",
    "        return nn.Sequential(*layers[:-1])\n",
    "\n",
    "    def reset(self, mode=None):\n",
    "        \"\"\"Start a new episode; in 'train' mode also refresh the target net.\"\"\"\n",
    "        self.mode = mode\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory = []\n",
    "            self.target_net = copy.deepcopy(self.evaluate_net)\n",
    "\n",
    "    def step(self, observation, reward, done):\n",
    "        \"\"\"Sample an action for `observation`.\n",
    "\n",
    "        In 'train' mode the call also records the transition that led to\n",
    "        `observation` (`reward` and `done` belong to the previous action)\n",
    "        and runs one learning step once the replay memory holds at least\n",
    "        `learn_start` transitions.\n",
    "\n",
    "        Returns:\n",
    "            The sampled action as a Python int.\n",
    "        \"\"\"\n",
    "        state_tensor = torch.as_tensor(observation,\n",
    "                dtype=torch.float).squeeze(0)\n",
    "        # Action selection needs no gradients, so skip building the graph.\n",
    "        with torch.no_grad():\n",
    "            logit_tensor = self.evaluate_net(state_tensor) / self.alpha\n",
    "        # Categorical normalises the logits itself: policy = softmax(Q / alpha).\n",
    "        action = distributions.Categorical(logits=logit_tensor).sample().item()\n",
    "        if self.mode == 'train':\n",
    "            self.trajectory += [observation, reward, done, action]\n",
    "            if len(self.trajectory) >= 8:\n",
    "                state, _, _, act, next_state, reward, done, _ = (\n",
    "                        self.trajectory[-8:])\n",
    "                self.replayer.store(state, act, reward, next_state, done)\n",
    "                # Only the latest entry is needed to form the next transition.\n",
    "                del self.trajectory[:-4]\n",
    "            if self.replayer.count >= self.learn_start:\n",
    "                self.learn()\n",
    "        return action\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"End-of-episode hook; this agent has nothing to release.\"\"\"\n",
    "        pass\n",
    "\n",
    "    def learn(self):\n",
    "        \"\"\"Run one gradient step on a batch of replayed transitions.\"\"\"\n",
    "        # replay\n",
    "        states, actions, rewards, next_states, dones = (\n",
    "                self.replayer.sample(self.batch_size))\n",
    "        state_tensor = torch.as_tensor(states, dtype=torch.float)\n",
    "        action_tensor = torch.as_tensor(actions, dtype=torch.long)\n",
    "        reward_tensor = torch.as_tensor(rewards, dtype=torch.float)\n",
    "        next_state_tensor = torch.as_tensor(next_states, dtype=torch.float)\n",
    "        done_tensor = torch.as_tensor(dones, dtype=torch.float)\n",
    "\n",
    "        # Soft target r + gamma * (1 - done) * alpha * logsumexp(Q' / alpha);\n",
    "        # it is a constant for the optimizer, hence no gradient tracking.\n",
    "        with torch.no_grad():\n",
    "            next_q_tensor = self.target_net(next_state_tensor)\n",
    "            next_v_tensor = self.alpha * torch.logsumexp(\n",
    "                    next_q_tensor / self.alpha, dim=-1)\n",
    "            target_tensor = reward_tensor + \\\n",
    "                    self.gamma * (1. - done_tensor) * next_v_tensor\n",
    "\n",
    "        # Regress Q(s, a) of the taken actions onto the target.\n",
    "        pred_tensor = self.evaluate_net(state_tensor)\n",
    "        q_tensor = pred_tensor.gather(1, action_tensor.unsqueeze(1)).squeeze(1)\n",
    "        loss_tensor = self.loss(q_tensor, target_tensor)\n",
    "        self.optimizer.zero_grad()\n",
    "        loss_tensor.backward()\n",
    "        self.optimizer.step()\n",
    "\n",
    "\n",
    "# Create the agent for the environment built in an earlier cell.\n",
    "agent = SQLAgent(env)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "01:20:36 [INFO] ==== train ====\n",
      "01:20:38 [DEBUG] train episode 0: reward = -152.58, steps = 84\n",
      "01:20:39 [DEBUG] train episode 1: reward = -142.24, steps = 67\n",
      "01:20:42 [DEBUG] train episode 2: reward = -141.73, steps = 66\n",
      "01:20:44 [DEBUG] train episode 3: reward = -151.91, steps = 81\n",
      "01:20:46 [DEBUG] train episode 4: reward = -277.57, steps = 82\n",
      "01:20:47 [DEBUG] train episode 5: reward = -160.83, steps = 72\n",
      "01:20:48 [DEBUG] train episode 6: reward = -111.87, steps = 52\n",
      "01:20:51 [DEBUG] train episode 7: reward = -181.23, steps = 76\n",
      "01:20:53 [DEBUG] train episode 8: reward = -485.51, steps = 117\n",
      "01:20:57 [DEBUG] train episode 9: reward = -513.05, steps = 123\n",
      "01:20:59 [DEBUG] train episode 10: reward = -133.94, steps = 64\n",
      "01:21:10 [DEBUG] train episode 11: reward = 249.74, steps = 445\n",
      "01:21:11 [DEBUG] train episode 12: reward = -165.85, steps = 69\n",
      "01:21:13 [DEBUG] train episode 13: reward = -99.38, steps = 56\n",
      "01:21:15 [DEBUG] train episode 14: reward = -123.65, steps = 68\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-6-938354c7aead>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m     21\u001b[0m \u001b[0mepisode_rewards\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     22\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mepisode\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mitertools\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcount\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 23\u001b[1;33m     episode_reward, elapsed_steps = play_episode(env.unwrapped, agent,\n\u001b[0m\u001b[0;32m     24\u001b[0m             max_episode_steps=env._max_episode_steps, render=1,mode='train')\n\u001b[0;32m     25\u001b[0m     \u001b[0mepisode_rewards\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mepisode_reward\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-6-938354c7aead>\u001b[0m in \u001b[0;36mplay_episode\u001b[1;34m(env, agent, max_episode_steps, mode, render)\u001b[0m\n\u001b[0;32m      4\u001b[0m     \u001b[0mepisode_reward\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0melapsed_steps\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m0.\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m     \u001b[1;32mwhile\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m         \u001b[0maction\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0magent\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mobservation\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      7\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mrender\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      8\u001b[0m             \u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrender\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-4-2679525f5962>\u001b[0m in \u001b[0;36mstep\u001b[1;34m(self, observation, reward, done)\u001b[0m\n\u001b[0;32m     46\u001b[0m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreplayer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstore\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mact\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mreward\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnext_state\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     47\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreplayer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcount\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[1;36m500\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 48\u001b[1;33m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlearn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     49\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0maction\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     50\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-4-2679525f5962>\u001b[0m in \u001b[0;36mlearn\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m     64\u001b[0m         \u001b[1;31m# train\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     65\u001b[0m         \u001b[0mnext_q_tensor\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtarget_net\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_state_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 66\u001b[1;33m         \u001b[0mnext_v_tensor\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0malpha\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mlogsumexp\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mnext_q_tensor\u001b[0m \u001b[1;33m/\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0malpha\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m-\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     67\u001b[0m         \u001b[0mtarget_tensor\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mreward_tensor\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mgamma\u001b[0m \u001b[1;33m*\u001b[0m \u001b[1;33m(\u001b[0m\u001b[1;36m1.\u001b[0m \u001b[1;33m-\u001b[0m \u001b[0mdone_tensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m*\u001b[0m \u001b[0mnext_v_tensor\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     68\u001b[0m         \u001b[0mpred_tensor\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mevaluate_net\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstate_tensor\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):\n",
    "    \"\"\"Run one episode of `agent` in `env` and report its outcome.\n",
    "\n",
    "    The agent is reset with `mode`, asked for an action on every\n",
    "    observation (including a terminal one, so it sees the final reward and\n",
    "    done flag) and closed once the episode is over. When the step cap ends\n",
    "    the episode, the agent is not called on the last observation.\n",
    "\n",
    "    Args:\n",
    "        env: environment whose reset() returns an observation and whose\n",
    "            step(action) returns (observation, reward, done, info).\n",
    "        agent: object providing reset(mode=...), step(observation, reward,\n",
    "            done) and close().\n",
    "        max_episode_steps: optional step cap; a falsy value disables it.\n",
    "        mode: forwarded to agent.reset().\n",
    "        render: when truthy, env.render() is called after every agent step.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of (summed reward, number of environment steps taken).\n",
    "    \"\"\"\n",
    "    observation = env.reset()\n",
    "    reward = 0.\n",
    "    done = False\n",
    "    agent.reset(mode=mode)\n",
    "    total_reward = 0.\n",
    "    step_count = 0\n",
    "    running = True\n",
    "    while running:\n",
    "        action = agent.step(observation, reward, done)\n",
    "        if render:\n",
    "            env.render()\n",
    "        if done:\n",
    "            running = False\n",
    "        else:\n",
    "            observation, reward, done, _ = env.step(action)\n",
    "            total_reward += reward\n",
    "            step_count += 1\n",
    "            if max_episode_steps and step_count >= max_episode_steps:\n",
    "                running = False\n",
    "    agent.close()\n",
    "    return total_reward, step_count\n",
    "\n",
    "\n",
    "# Train until the mean reward of the latest (up to 10) episodes exceeds 250.\n",
    "# NOTE(review): render=1 makes play_episode call env.render() on every\n",
    "# step; presumably this slows training and needs a display -- confirm it\n",
    "# is wanted here.\n",
    "logging.info('==== train ====')\n",
    "episode_rewards = []\n",
    "episode = 0\n",
    "while True:\n",
    "    episode_reward, elapsed_steps = play_episode(\n",
    "            env.unwrapped, agent, mode='train', render=1,\n",
    "            max_episode_steps=env._max_episode_steps)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('train episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "    if np.mean(episode_rewards[-10:]) > 250:\n",
    "        break\n",
    "    episode += 1\n",
    "\n",
    "# Plot the reward of every training episode.\n",
    "plt.plot(episode_rewards)\n",
    "\n",
    "\n",
    "# Evaluate the agent over 100 episodes with the agent mode left as None.\n",
    "logging.info('==== test ====')\n",
    "episode_rewards = []\n",
    "for episode in range(100):\n",
    "    episode_reward, elapsed_steps = play_episode(env, agent, mode=None,\n",
    "            render=1)\n",
    "    episode_rewards.append(episode_reward)\n",
    "    logging.debug('test episode %d: reward = %.2f, steps = %d',\n",
    "            episode, episode_reward, elapsed_steps)\n",
    "\n",
    "# Summarise the evaluation as mean and standard deviation of the rewards.\n",
    "logging.info('average episode reward = %.2f ± %.2f',\n",
    "        *(statistic(episode_rewards) for statistic in (np.mean, np.std)))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
